// Copyright 2020 Douyu
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package gorm

import (
	"context"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/douyu/jupiter/pkg/util/xdebug"
	"github.com/douyu/jupiter/pkg/xlog"

	"github.com/jinzhu/gorm"
	// mysql driver
	_ "github.com/jinzhu/gorm/dialects/mysql"
)

// Type aliases re-exporting github.com/jinzhu/gorm types from this package.
type (
	// SQLCommon is an alias of gorm.SQLCommon.
	SQLCommon = gorm.SQLCommon
	// Callback is an alias of gorm.Callback.
	Callback = gorm.Callback
	// CallbackProcessor is an alias of gorm.CallbackProcessor.
	CallbackProcessor = gorm.CallbackProcessor
	// Dialect is an alias of gorm.Dialect.
	Dialect = gorm.Dialect
	// Scope is an alias of gorm.Scope.
	Scope = gorm.Scope
	// DB is an alias of gorm.DB.
	DB = gorm.DB
	// Model is an alias of gorm.Model.
	Model = gorm.Model
	// ModelStruct is an alias of gorm.ModelStruct.
	ModelStruct = gorm.ModelStruct
	// Field is an alias of gorm.Field.
	Field = gorm.Field
	// StructField is an alias of gorm.StructField.
	StructField = gorm.StructField
	// RowQueryResult is an alias of gorm.RowQueryResult.
	RowQueryResult = gorm.RowQueryResult
	// RowsQueryResult is an alias of gorm.RowsQueryResult.
	RowsQueryResult = gorm.RowsQueryResult
	// Association is an alias of gorm.Association.
	Association = gorm.Association
	// Errors is an alias of gorm.Errors.
	Errors = gorm.Errors
	// Logger is an alias of gorm.Logger.
	Logger = gorm.Logger
)

var (
	// errSlowCommand is attached (as the "command" field) to every slow-query
	// log entry emitted by the callbacks installed in Open.
	errSlowCommand = errors.New("mysql slow command")

	// IsRecordNotFoundError is an alias of gorm.IsRecordNotFoundError; it
	// reports whether an error is gorm's record-not-found error.
	IsRecordNotFoundError = gorm.IsRecordNotFoundError

	// ErrRecordNotFound is returned when a query into a struct finds no record.
	// It occurs only when querying with a struct; querying with a slice won't
	// return this error.
	ErrRecordNotFound = gorm.ErrRecordNotFound
	// ErrInvalidSQL occurs when you attempt a query with invalid SQL.
	ErrInvalidSQL = gorm.ErrInvalidSQL
	// ErrInvalidTransaction occurs when `Commit` or `Rollback` is called on a
	// DB that is not inside a transaction.
	ErrInvalidTransaction = gorm.ErrInvalidTransaction
	// ErrCantStartTransaction occurs when a transaction cannot be started with `Begin`.
	ErrCantStartTransaction = gorm.ErrCantStartTransaction
	// ErrUnaddressable is returned when gorm is given an unaddressable value.
	ErrUnaddressable = gorm.ErrUnaddressable
)

// WithContext returns a copy of db that carries ctx under the "_context"
// setting key (retrievable with Get("_context")).
//
// db itself is not modified: the setting is stored on a clone via gorm's Set,
// so a shared *DB (for example a package-level handle used by concurrent
// requests) never picks up, or leaks, a per-request context. Callers must use
// the returned *DB.
func WithContext(ctx context.Context, db *DB) *DB {
	return db.Set("_context", ctx)
}

// Open opens a database handle for dialect using options.DSN and configures it
// from options:
//   - gorm's SQL logging (LogMode) is enabled when options.Debug is set or when
//     xdebug reports development mode;
//   - the pool limits MaxIdleConns, MaxOpenConns and (when non-zero)
//     ConnMaxLifetime are applied to the underlying *sql.DB;
//   - the gorm:delete, gorm:update, gorm:create, gorm:query and gorm:row_query
//     callbacks are wrapped so that an execution taking longer than
//     options.SlowThreshold (when positive) is reported via xlog.Error.
//
// options.DSN is also parsed with ParseDSN to obtain the database name and
// address used in slow-query log entries. On error the returned *DB is nil,
// and any handle opened before the failure is closed.
func Open(dialect string, options *Config) (*DB, error) {
	if options == nil {
		return nil, errors.New("gorm: nil config")
	}

	inner, err := gorm.Open(dialect, options.DSN)
	if err != nil {
		return nil, err
	}

	// Development mode always turns SQL logging on, regardless of options.Debug.
	inner.LogMode(options.Debug || xdebug.IsDevelopmentMode())

	// Default connection pool settings.
	inner.DB().SetMaxIdleConns(options.MaxIdleConns)
	inner.DB().SetMaxOpenConns(options.MaxOpenConns)
	if options.ConnMaxLifetime != 0 {
		inner.DB().SetConnMaxLifetime(options.ConnMaxLifetime)
	}

	d, err := ParseDSN(options.DSN)
	if err != nil {
		// Release the pool opened above instead of leaking it; a close error
		// is secondary to the parse error reported to the caller.
		_ = inner.Close()
		return nil, err
	}

	// label identifies this database as "<dbname>_<addr>". It is handed to
	// every callback wrapper; the slow-log wrapper below does not use it.
	label := fmt.Sprintf("%s_%s", d.DBName, d.Addr)

	// replace swaps the named callback of processor for wrapper(original, label).
	replace := func(processor func() *gorm.CallbackProcessor, callbackName string, wrapper func(func(*Scope), string) func(*Scope)) {
		old := processor().Get(callbackName)
		processor().Replace(callbackName, wrapper(old, label))
	}

	// slowLog times callback and logs executions slower than
	// options.SlowThreshold; a non-positive threshold disables the log.
	slowLog := func(callback func(scope *Scope), _ string) func(scope *Scope) {
		return func(scope *Scope) {
			beg := time.Now()
			callback(scope)
			cost := time.Since(beg)

			if options.SlowThreshold > 0 && cost > options.SlowThreshold {
				// TODO: the SQL must be desensitized (sensitive values masked)
				// before it can be recorded in logs.
				xlog.Error("slow",
					xlog.Any("command", errSlowCommand),
					xlog.Any("sql", scope.CombinedConditionSql()),
					xlog.Any("addr", d.Addr),
					xlog.Any("table", d.DBName+"."+scope.TableName()),
					xlog.Any("cost", cost),
				)
			}
		}
	}

	replace(inner.Callback().Delete, "gorm:delete", slowLog)
	replace(inner.Callback().Update, "gorm:update", slowLog)
	replace(inner.Callback().Create, "gorm:create", slowLog)
	replace(inner.Callback().Query, "gorm:query", slowLog)
	replace(inner.Callback().RowQuery, "gorm:row_query", slowLog)

	return inner, nil
}

// getStatement collapses a database error message into a low-cardinality
// status label, so that metrics keyed by it (e.g. Prometheus series) stay few.
//
// MySQL driver errors look like "Error 1062: Duplicate entry ..." (newer
// driver versions: "Error 1062 (23000): ..."); for such messages the text
// before the first ':' is returned, e.g. "Error 1062". Any other message,
// including one starting with "Error" but containing no ':', maps to "Unknown".
func getStatement(err string) string {
	// Match the driver's "Error" prefix; anything else is not a MySQL server error.
	if !strings.HasPrefix(err, "Error") {
		return "Unknown"
	}
	i := strings.Index(err, ":")
	if i < 0 {
		return "Unknown"
	}

	// Keep only the error code part, dropping the (high-cardinality) message.
	return err[:i]
}
